# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" """
import mindspore.ops as ops
import mindspore.nn as nn

class LayerBase(nn.Cell):
    """Two stacked 3x3 convolutions, each followed by a ReLU activation.

    Used as the basic building block on both the contracting and the
    expanding path of ``Unet``.

    Args:
        in_channels (int): Number of channels of the input tensor.
        out_channels (int): Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super(LayerBase, self).__init__()
        # NOTE(review): relies on MindSpore's default Conv2d pad_mode ('same'),
        # which keeps the spatial size unchanged — confirm this is intended.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3)
        # The activation cell must be created once and then *called* on the
        # tensor; ``nn.ReLU(x)`` would build a new cell with ``x`` as a
        # constructor argument instead of applying ReLU to ``x``.
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x

class Encoder(nn.Cell):
    """Downsampling step of the U-Net contracting path (max pooling).

    ``nn.MaxPool2d`` defaults to ``stride=1`` in MindSpore, so the stride is
    passed explicitly. With ``stride=1`` a 2x2 window would only shrink the
    feature map by one pixel instead of halving it.

    Args:
        kernel_size (int): Pooling window size. Defaults to 2.
        stride (int): Pooling stride. Defaults to 2.
    """

    def __init__(self, kernel_size=2, stride=2):
        super(Encoder, self).__init__()
        self.encode = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)

    def construct(self, x):
        return self.encode(x)

class Decoder(nn.Cell):
    """Upsampling step of the U-Net expanding path.

    Upsamples the input with a 2x2, stride-2 transposed convolution that
    halves the channel count. It then zero-pads the result so its spatial size
    matches the skip-connection tensor, and concatenates both along the
    channel axis.

    Args:
        in_channels (int): Channels of the tensor being upsampled. The skip
            tensor is expected to carry ``in_channels // 2`` channels, so the
            concatenated output has ``in_channels`` channels.
    """

    def __init__(self, in_channels):
        super(Decoder, self).__init__()
        self.decode = nn.Conv2dTranspose(in_channels, in_channels // 2, kernel_size=2, stride=2)
        # Concatenate along the channel axis (axis 1, assuming NCHW layout).
        self.concat = ops.Concat(axis=1)

    def construct(self, x, origin_x):
        """Upsample ``x`` and join it with the skip tensor ``origin_x``.

        Args:
            x: Feature map to upsample (assumed NCHW — TODO confirm).
            origin_x: Skip-connection feature map from the contracting path.

        Returns:
            ``origin_x`` and the padded, upsampled ``x`` concatenated on axis 1.
        """
        x = self.decode(x)
        # ``Tensor.size`` is the element count in MindSpore, not the shape, so
        # use ``.shape``: index 2 is height, index 3 is width (NCHW).
        dy = origin_x.shape[2] - x.shape[2]
        dx = origin_x.shape[3] - x.shape[3]
        # Split any size mismatch as evenly as possible between both sides.
        # NOTE(review): ops.Pad needs non-negative paddings, so this assumes
        # the upsampled map is never larger than the skip tensor.
        paddings = ((0, 0), (0, 0), (dy // 2, dy - dy // 2), (dx // 2, dx - dx // 2))
        x = ops.Pad(paddings)(x)
        return self.concat((origin_x, x))


class Unet(nn.Cell):
    """U-Net style encoder/decoder network with skip connections.

    Four ``LayerBase`` + max-pool stages form the contracting path, followed
    by a 512->1024 bottleneck. Four ``Decoder`` + ``LayerBase`` stages form
    the expanding path, and a final 1x1 convolution maps 64 channels to
    ``n_class`` channels.

    Args:
        n_class (int): Number of output channels of the final 1x1 convolution.
        in_channels (int): Number of channels of the input tensor. Defaults
            to 1 (the previously hard-coded value).
    """

    def __init__(self, n_class, in_channels=1):
        super(Unet, self).__init__()
        # Contracting path.
        self.lhc1 = LayerBase(in_channels, 64)
        self.lhc2 = LayerBase(64, 128)
        self.lhc3 = LayerBase(128, 256)
        self.lhc4 = LayerBase(256, 512)

        self.bottom = LayerBase(512, 1024)

        # Expanding path: each stage consumes the concatenation produced by
        # the matching Decoder (upsampled features + skip connection).
        self.rhc1 = LayerBase(1024, 512)
        self.rhc2 = LayerBase(512, 256)
        self.rhc3 = LayerBase(256, 128)
        self.rhc4 = LayerBase(128, 64)

        # Pooling has no parameters, so one instance is shared by all stages.
        self.encoder = Encoder()
        self.decoder1 = Decoder(1024)
        self.decoder2 = Decoder(512)
        self.decoder3 = Decoder(256)
        self.decoder4 = Decoder(128)

        # Previously hard-coded to 2 output channels, silently ignoring n_class.
        self.output = nn.Conv2d(64, n_class, 1)

    def construct(self, x):
        """Run the network; MindSpore cells execute ``construct`` when called."""
        l1 = self.lhc1(x)
        l2 = self.lhc2(self.encoder(l1))
        l3 = self.lhc3(self.encoder(l2))
        l4 = self.lhc4(self.encoder(l3))

        bottom = self.bottom(self.encoder(l4))

        r1 = self.rhc1(self.decoder1(bottom, l4))
        r2 = self.rhc2(self.decoder2(r1, l3))
        r3 = self.rhc3(self.decoder3(r2, l2))
        # Must be rhc4 (128 -> 64); lhc4 (256 -> 512) does not match here.
        r4 = self.rhc4(self.decoder4(r3, l1))

        return self.output(r4)

    def forward(self, x):
        """Alias of ``construct`` kept for callers that used the old name."""
        return self.construct(x)

